////
// Copyright 2022 The GUAC Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package osv

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"sort"
	"strings"
	"time"

	"github.com/guacsec/guac/pkg/assembler/clients/generated"
	"github.com/guacsec/guac/pkg/certifier"
	attestation_vuln "github.com/guacsec/guac/pkg/certifier/attestation/vuln"
	"github.com/guacsec/guac/pkg/certifier/components/root_package"
	"github.com/guacsec/guac/pkg/clients"
	"github.com/guacsec/guac/pkg/events"
	"github.com/guacsec/guac/pkg/handler/processor"
	"github.com/guacsec/guac/pkg/version"

	osv_models "github.com/google/osv-scanner/pkg/models"
	osv_scanner "github.com/google/osv-scanner/pkg/osv"
	attestationv1 "github.com/in-toto/attestation/go/v1"
	jsoniter "github.com/json-iterator/go"
	"golang.org/x/time/rate"
)

// json is json-iterator's encoding/json-compatible configuration; it is used
// to marshal the generated vulnerability attestations.
var json = jsoniter.ConfigCompatibleWithStandardLibrary

// rateLimit and rateLimitInterval are the parameters used to build the
// rate.Limiter that guards the osv.dev HTTP client created by
// NewOSVCertificationParser.
var (
	rateLimit         = 10000
	rateLimitInterval = 30 * time.Second
)

// URI is recorded as the scanner URI of every vulnerability attestation.
const URI string = "osv.dev"

// VERSION is recorded as the scanner version of every vulnerability attestation.
const VERSION string = "0.0.14"

// INVOC_URI is an invocation URI constant exported by this package
// (not referenced elsewhere in this file).
const INVOC_URI string = "guac"

// PRODUCER_ID is a producer identifier exported by this package
// (not referenced elsewhere in this file).
const PRODUCER_ID string = "guacsec/guac"

// OSVCollector is reported as both the collector and the source of every
// document this certifier produces.
const OSVCollector string = "osv_certifier"

// ErrOSVComponentTypeMismatch is returned by CertifyComponent when the root
// component it receives is not a []*root_package.PackageNode.
var ErrOSVComponentTypeMismatch = errors.New("rootComponent type is not []*root_package.PackageNode")

// ErrOSVComponenetTypeMismatch is the original, misspelled name of
// ErrOSVComponentTypeMismatch. Both names hold the same error value, so
// comparisons with == or errors.Is against either one behave identically.
//
// Deprecated: use ErrOSVComponentTypeMismatch instead.
var ErrOSVComponenetTypeMismatch = ErrOSVComponentTypeMismatch

// osvCertifier looks up vulnerabilities for GUAC package nodes on osv.dev.
type osvCertifier struct {
	// osvHTTPClient performs the batch query (and, when enabled, the
	// metadata hydration) requests sent to osv.dev.
	osvHTTPClient *http.Client
	// withVulnerabilityMetadata selects hydrated OSV results instead of the
	// ID-only results of the batch query.
	withVulnerabilityMetadata bool
}

// CertifierOpts is a functional option applied to an osvCertifier by
// NewOSVCertificationParser.
type CertifierOpts func(*osvCertifier)

// WithVulnerabilityMetadata returns an option that makes the certifier hydrate
// OSV results with full vulnerability records (including severities) rather
// than reporting vulnerability IDs only.
func WithVulnerabilityMetadata() CertifierOpts {
	enableMetadata := func(c *osvCertifier) {
		c.withVulnerabilityMetadata = true
	}
	return enableMetadata
}

// NewOSVCertificationParser initializes the OSV certifier.
//
// Requests to osv.dev go through a rate-limited transport that allows up to
// rateLimit requests per rateLimitInterval. Options such as
// WithVulnerabilityMetadata are applied in order after the defaults are set.
func NewOSVCertificationParser(opts ...CertifierOpts) certifier.Certifier {
	// rate.Every takes the minimum spacing between two events, so spread
	// rateLimit requests evenly across rateLimitInterval and allow one full
	// interval's worth as burst. Passing rateLimitInterval itself would refill
	// only a single token per interval once the initial burst is spent, which
	// throttles long-running certification (metadata hydration issues one
	// request per vulnerability) to one request per interval.
	limiter := rate.NewLimiter(rate.Every(rateLimitInterval/time.Duration(rateLimit)), rateLimit)
	transport := clients.NewRateLimitedTransport(version.UATransport, limiter)
	client := &http.Client{Transport: transport}

	o := &osvCertifier{
		osvHTTPClient: client,
	}
	for _, opt := range opts {
		opt(o)
	}
	return o
}

// CertifyComponent takes the root component from the GUAC database, which must
// be a []*root_package.PackageNode, batch-queries OSV for the purls of those
// packages, and sends the resulting vulnerability attestations to docChannel
// (when docChannel is non-nil).
//
// It returns ErrOSVComponenetTypeMismatch if rootComponent has any other type,
// and a wrapped error if querying OSV or building the documents fails.
func (o *osvCertifier) CertifyComponent(ctx context.Context, rootComponent interface{}, docChannel chan<- *processor.Document) error {
	packageNodes, ok := rootComponent.([]*root_package.PackageNode)
	if !ok {
		return ErrOSVComponenetTypeMismatch
	}

	purls := make([]string, 0, len(packageNodes))
	for _, node := range packageNodes {
		// A nil entry carries no purl; skip it rather than panic on the dereference.
		if node == nil {
			continue
		}
		purls = append(purls, node.Purl)
	}

	if _, err := EvaluateOSVResponse(ctx, o.osvHTTPClient, purls, docChannel, o.withVulnerabilityMetadata); err != nil {
		return fmt.Errorf("could not generate document from OSV results: %w", err)
	}
	return nil
}

// EvaluateOSVResponse takes a list of purls and batch queries OSV for
// vulnerability information.
//
// Purls containing "pkg:guac" are skipped because they will not be found in
// OSV. Duplicate purls are queried once. When withVulnerabilityMetadata is
// true, the batch results are hydrated with full vulnerability records.
// Otherwise only the vulnerability IDs are kept. One document per queried purl
// is returned, and each is also sent on docChannel when docChannel is non-nil.
//
// If there is nothing to query, no request is sent and (nil, nil) is returned.
// ctx is currently not consulted: the osv-scanner calls used here do not take
// a context.
func EvaluateOSVResponse(ctx context.Context, client *http.Client, purls []string, docChannel chan<- *processor.Document, withVulnerabilityMetadata bool) ([]*processor.Document, error) {
	var query osv_scanner.BatchedQuery
	seen := make(map[string]struct{}, len(purls))

	for _, purl := range purls {
		// skip any purls that are generated by GUAC as they will not be found in OSV
		if strings.Contains(purl, "pkg:guac") {
			continue
		}
		if _, dup := seen[purl]; dup {
			continue
		}
		seen[purl] = struct{}{}
		query.Queries = append(query.Queries, osv_scanner.MakePURLRequest(purl))
	}

	// Without this guard an empty batch would still be sent to osv.dev only to
	// yield an empty result set.
	if len(query.Queries) == 0 {
		return nil, nil
	}

	resp, err := osv_scanner.MakeRequestWithClient(query, client)
	if err != nil {
		return nil, fmt.Errorf("osv.dev batched request failed: %w", err)
	}

	// Results are matched to queries by position, so a short response must be
	// rejected rather than indexed out of range.
	responseMap := make(map[string][]osv_models.Vulnerability, len(query.Queries))
	if withVulnerabilityMetadata {
		hydratedResp, err := osv_scanner.HydrateWithClient(resp, client)
		if err != nil {
			return nil, fmt.Errorf("hydrating vulnerability with metadata failed: %w", err)
		}
		if len(hydratedResp.Results) < len(query.Queries) {
			return nil, fmt.Errorf("osv.dev returned %d hydrated results for %d queries", len(hydratedResp.Results), len(query.Queries))
		}
		for i, q := range query.Queries {
			responseMap[q.Package.PURL] = hydratedResp.Results[i].Vulns
		}
	} else {
		if len(resp.Results) < len(query.Queries) {
			return nil, fmt.Errorf("osv.dev returned %d results for %d queries", len(resp.Results), len(query.Queries))
		}
		for i, q := range query.Queries {
			minimalVulns := resp.Results[i].Vulns
			vulns := make([]osv_models.Vulnerability, 0, len(minimalVulns))
			for j := range minimalVulns {
				vulns = append(vulns, minimalVulnerabilityToVulnerability(&minimalVulns[j]))
			}
			responseMap[q.Package.PURL] = vulns
		}
	}
	return generateDocument(responseMap, docChannel)
}

// generateDocument builds one ITE6 vulnerability processor document for
// ingestion per purl key in responseMap, including purls with no
// vulnerabilities.
//
// Purls are processed in sorted order so repeated runs produce documents in
// the same order. Each document is sent on docChannel when docChannel is
// non-nil, and is also collected into the returned slice. A marshalling
// failure aborts immediately; documents already sent on docChannel are not
// recalled.
func generateDocument(responseMap map[string][]osv_models.Vulnerability, docChannel chan<- *processor.Document) ([]*processor.Document, error) {
	purls := make([]string, 0, len(responseMap))
	for purl := range responseMap {
		purls = append(purls, purl)
	}
	// Map iteration order is randomized; sort for deterministic output.
	sort.Strings(purls)

	var generatedOSVDocs []*processor.Document
	for _, purl := range purls {
		currentTime := time.Now()
		payload, err := json.Marshal(createAttestation(purl, responseMap[purl], currentTime))
		if err != nil {
			return nil, fmt.Errorf("unable to marshal attestation: %w", err)
		}
		doc := &processor.Document{
			Blob:   payload,
			Type:   processor.DocumentITE6Vul,
			Format: processor.FormatJSON,
			SourceInformation: processor.SourceInformation{
				Collector:   OSVCollector,
				Source:      OSVCollector,
				DocumentRef: events.GetDocRef(payload),
			},
		}
		if docChannel != nil {
			docChannel <- doc
		}
		generatedOSVDocs = append(generatedOSVDocs, doc)
	}
	return generatedOSVDocs, nil
}

// createAttestation builds the in-toto vulnerability statement for one purl.
//
// The purl is the statement's only subject. URI and VERSION identify the
// scanner, and currentTime is used for both the scan start and finish
// timestamps. Each entry in vulns becomes one scanner result.
func createAttestation(purl string, vulns []osv_models.Vulnerability, currentTime time.Time) *attestation_vuln.VulnerabilityStatement {
	subject := &attestationv1.ResourceDescriptor{Uri: purl}

	statement := &attestation_vuln.VulnerabilityStatement{
		Statement: attestationv1.Statement{
			Type:          attestationv1.StatementTypeUri,
			PredicateType: attestation_vuln.PredicateVuln,
			Subject:       []*attestationv1.ResourceDescriptor{subject},
		},
		Predicate: attestation_vuln.VulnerabilityPredicate{
			Scanner: attestation_vuln.Scanner{
				Uri:     URI,
				Version: VERSION,
			},
			Metadata: attestation_vuln.Metadata{
				ScanStartedOn:  &currentTime,
				ScanFinishedOn: &currentTime,
			},
		},
	}

	for i := range vulns {
		statement.Predicate.Scanner.Result = append(statement.Predicate.Scanner.Result, resultFromOSVVulnerability(&vulns[i]))
	}
	return statement
}

// resultFromOSVVulnerability converts one OSV vulnerability into an attestation
// result. It maps the CVSS v2/v3/v4 severity types onto GUAC's
// VulnerabilityScoreType names and passes any other severity type through as
// its raw string.
func resultFromOSVVulnerability(vuln *osv_models.Vulnerability) attestation_vuln.Result {
	res := attestation_vuln.Result{Id: vuln.ID}
	for _, sev := range vuln.Severity {
		method := string(sev.Type)
		switch sev.Type {
		case osv_models.SeverityCVSSV2:
			method = string(generated.VulnerabilityScoreTypeCvssv2)
		case osv_models.SeverityCVSSV3:
			method = string(generated.VulnerabilityScoreTypeCvssv3)
		case osv_models.SeverityCVSSV4:
			method = string(generated.VulnerabilityScoreTypeCvssv4)
		}
		res.Severity = append(res.Severity, attestation_vuln.Severity{
			Method: method,
			Score:  sev.Score,
		})
	}
	return res
}

// minimalVulnerabilityToVulnerability widens a batch-query hit into an
// osv_models.Vulnerability. Only the ID is copied; every other field is left
// at its zero value.
func minimalVulnerabilityToVulnerability(minimal *osv_scanner.MinimalVulnerability) osv_models.Vulnerability {
	var full osv_models.Vulnerability
	full.ID = minimal.ID
	return full
}
